{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270a43c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ea1c0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "import torch\n",
    "from whisperspeech.t2s_up_wds_mlang_enclm import TSARTransformer\n",
    "from whisperspeech.s2a_delar_mup_wds_mlang import SADelARTransformer\n",
    "from whisperspeech.a2wav import Vocoder\n",
    "import traceback\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "502ea753",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Pipeline:\n",
    "    default_speaker = torch.tensor(\n",
    "       [-0.2929, -0.4503,  0.4155, -0.1417,  0.0473, -0.1624, -0.2322,  0.7071,\n",
    "         0.4800,  0.5496,  0.0410,  0.6236,  0.4729,  0.0587,  0.2194, -0.0466,\n",
    "        -0.3036,  0.0497,  0.5028, -0.1703,  0.5039, -0.6464,  0.3857, -0.7350,\n",
    "        -0.1605,  0.4808,  0.5397, -0.4851,  0.1774, -0.8712,  0.5789,  0.1785,\n",
    "        -0.1417,  0.3039,  0.4232, -0.0186,  0.2685,  0.6153, -0.3103, -0.5706,\n",
    "        -0.4494,  0.3394, -0.6184, -0.3617,  1.1041, -0.1178, -0.1885,  0.1997,\n",
    "         0.5571, -0.2906, -0.0477, -0.4048, -0.1062,  1.4779,  0.1639, -0.3712,\n",
    "        -0.1776, -0.0568, -0.6162,  0.0110, -0.0207, -0.1319, -0.3854,  0.7248,\n",
    "         0.0343,  0.5724,  0.0670,  0.0486, -0.3813,  0.1738,  0.3017,  1.0502,\n",
    "         0.1550,  0.5708,  0.0366,  0.5093,  0.0294, -0.7091, -0.8220, -0.1583,\n",
    "        -0.2343,  0.1366,  0.7372, -0.0631,  0.1505,  0.4600, -0.1252, -0.5245,\n",
    "         0.7523, -0.0386, -0.2587,  1.0066, -0.2037,  0.1617, -0.3800,  0.2790,\n",
    "         0.0184, -0.5111, -0.7291,  0.1627,  0.2367, -0.0192,  0.4822, -0.4458,\n",
    "         0.1457, -0.5884,  0.1909,  0.2563, -0.2035, -0.0377,  0.7771,  0.2139,\n",
    "         0.3801,  0.6047, -0.6043, -0.2563, -0.0726,  0.3856,  0.3217,  0.0823,\n",
    "        -0.1302,  0.3287,  0.5693,  0.2453,  0.8231,  0.0072,  1.0327,  0.6065,\n",
    "        -0.0620, -0.5572,  0.5220,  0.2485,  0.1520,  0.0222, -0.2179, -0.7392,\n",
    "        -0.3855,  0.1822,  0.1042,  0.7133,  0.3583,  0.0606, -0.0424, -0.9189,\n",
    "        -0.4882, -0.5480, -0.5719, -0.1660, -0.3439, -0.5814, -0.2542,  0.0197,\n",
    "         0.4942,  0.0915, -0.0420, -0.0035,  0.5578,  0.1051, -0.0891,  0.2348,\n",
    "         0.6876, -0.6685,  0.8215, -0.3692, -0.3150, -0.0462, -0.6806, -0.2661,\n",
    "        -0.0308, -0.0050,  0.6756, -0.1647,  1.0734,  0.0049,  0.4969,  0.0259,\n",
    "        -0.8949,  0.0731,  0.0886,  0.3442, -0.1433, -0.6804,  0.2204,  0.1859,\n",
    "         0.2702,  0.1699, -0.1443, -0.9614,  0.3261,  0.1718,  0.3545, -0.0686]\n",
    "    )\n",
    "    \n",
    "    def __init__(self, t2s_ref=None, s2a_ref=None, optimize=True, torch_compile=False):\n",
    "        args = dict()\n",
    "        try:\n",
    "            if t2s_ref:\n",
    "                args[\"ref\"] = t2s_ref\n",
    "            self.t2s = TSARTransformer.load_model(**args).cuda()\n",
    "            if optimize: self.t2s.optimize(torch_compile=torch_compile)\n",
    "        except:\n",
    "            print(\"Failed to load the T2S model:\")\n",
    "            print(traceback.format_exc())\n",
    "        try:\n",
    "            if s2a_ref:\n",
    "                args[\"ref\"] = s2a_ref\n",
    "            self.s2a = SADelARTransformer.load_model(**args).cuda()\n",
    "            if optimize: self.s2a.optimize(torch_compile=torch_compile)\n",
    "        except:\n",
    "            print(\"Failed to load the S2A model:\")\n",
    "            print(traceback.format_exc())\n",
    "        self.vocoder = Vocoder()\n",
    "        self.encoder = None\n",
    "\n",
    "    def extract_spk_emb(self, fname):\n",
    "        \"\"\"Extracts a speaker embedding from the first 30 seconds of the give audio file.\n",
    "        \"\"\"\n",
    "        import torchaudio\n",
    "        if self.encoder is None:\n",
    "            from speechbrain.pretrained import EncoderClassifier\n",
    "            self.encoder = EncoderClassifier.from_hparams(\"speechbrain/spkrec-ecapa-voxceleb\",\n",
    "                                                          savedir=\"~/.cache/speechbrain/\",\n",
    "                                                          run_opts={\"device\": \"cuda\"})\n",
    "        samples, sr = torchaudio.load(fname)\n",
    "        samples = self.encoder.audio_normalizer(samples[0,:30*sr], sr)\n",
    "        spk_emb = self.encoder.encode_batch(samples)\n",
    "        return spk_emb[0,0]\n",
    "        \n",
    "    def generate_atoks(self, text, speaker=None, lang='en', cps=15, step_callback=None):\n",
    "        if speaker is None: speaker = self.default_speaker\n",
    "        elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker)\n",
    "        text = text.replace(\"\\n\", \" \")\n",
    "        stoks = self.t2s.generate(text, cps=cps, lang=lang, step=step_callback)\n",
    "        atoks = self.s2a.generate(stoks, speaker.unsqueeze(0), step=step_callback)\n",
    "        return atoks\n",
    "        \n",
    "    def generate(self, text, speaker=None, lang='en', cps=15, step_callback=None):\n",
    "        return self.vocoder.decode(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))\n",
    "    \n",
    "    def generate_to_file(self, fname, text, speaker=None, lang='en', cps=15, step_callback=None):\n",
    "        self.vocoder.decode_to_file(fname, self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))\n",
    "        \n",
    "    def generate_to_notebook(self, text, speaker=None, lang='en', cps=15, step_callback=None):\n",
    "        self.vocoder.decode_to_notebook(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
